{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n",
      "/home/liy0f/.conda/envs/tf_newest_py2/lib/python2.7/site-packages/keras_applications/resnet50.py:263: UserWarning: The output shape of `ResNet50(include_top=False)` has been changed since Keras 2.2.0.\n",
      "  warnings.warn('The output shape of `ResNet50(include_top=False)` '\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "# Restrict CUDA to device index 1; this is set before keras is imported below.\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '1'\n",
    "\n",
    "from keras.applications import VGG16\n",
    "import keras\n",
    "from keras import backend as K\n",
    "# Earlier VGG16 base, left disabled:\n",
    "# vgg_conv = VGG16(weights='imagenet', include_top=False, input_shape=(224, 224, 3))\n",
    "# ResNet50 with ImageNet weights and include_top=False, built for (224, 224, 3) inputs.\n",
    "resnet = keras.applications.resnet50.ResNet50(include_top=False, \n",
    "                                              weights='imagenet', input_shape=(224, 224, 3))\n",
    "# Set the global Keras learning phase to 1.\n",
    "# NOTE(review): this runs after `resnet` is built and before the layers added later in\n",
    "# the notebook; presumably the ordering matters for BatchNorm/Dropout behaviour during\n",
    "# training and validation - confirm it is intentional.\n",
    "K.set_learning_phase(1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Only the final four layers of the ResNet50 base remain trainable.\n",
    "for frozen_layer in resnet.layers[:-4]:\n",
    "    frozen_layer.trainable = False\n",
    "\n",
    "# Invoke every layer whose name begins with 'bn' on its own input with training=False.\n",
    "# The tensor returned by the call is not kept.\n",
    "for candidate in resnet.layers:\n",
    "    if not candidate.name.startswith('bn'):\n",
    "        continue\n",
    "    candidate.call(candidate.input, training=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from keras import layers, models, optimizers\n",
    "\n",
    "# Classifier: the ResNet50 convolutional base followed by a small head\n",
    "# (flatten, 50% dropout, 2-way softmax).\n",
    "model = models.Sequential([\n",
    "    resnet,\n",
    "    layers.Flatten(),\n",
    "    layers.Dropout(0.5),\n",
    "    layers.Dense(2, activation='softmax'),\n",
    "])\n",
    "\n",
    "# Uncomment to print the architecture and the trainable-parameter count.\n",
    "# model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Categorical cross-entropy for the softmax output, optimised with\n",
    "# SGD (learning rate 1e-4, momentum 0.9); accuracy is reported as a metric.\n",
    "sgd = optimizers.SGD(lr=1e-4, momentum=0.9)\n",
    "model.compile(optimizer=sgd,\n",
    "              loss='categorical_crossentropy',\n",
    "              metrics=['accuracy'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Directories handed to flow_from_directory further down.\n",
    "# Note: the validation images are read from the 'test' split of chest_xray.\n",
    "train_data_dir, validation_data_dir = 'chest_xray/train', 'chest_xray/test'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Training schedule: number of epochs and number of images per generator batch.\n",
    "epochs, batch_size = 10, 64"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Target image size; matches the (224, 224, 3) input_shape given to ResNet50.\n",
    "img_height = 224\n",
    "img_width = 224"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found 5232 images belonging to 2 classes.\n",
      "Found 624 images belonging to 2 classes.\n"
     ]
    }
   ],
   "source": [
    "from keras.applications.resnet50 import preprocess_input\n",
    "from keras.preprocessing.image import ImageDataGenerator\n",
    "\n",
    "# The ImageNet weights loaded into ResNet50 were trained on inputs normalised by the\n",
    "# matching `preprocess_input`, not on a plain 1/255 rescale, so both splits use it.\n",
    "# `preprocessing_function` receives one rank-3 image array at a time, after resizing.\n",
    "train_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)\n",
    "test_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)\n",
    "\n",
    "# Settings shared by both directory iterators: resize to the network input size and\n",
    "# emit one-hot labels to match the categorical cross-entropy loss.\n",
    "flow_kwargs = dict(\n",
    "    target_size=(img_height, img_width),\n",
    "    batch_size=batch_size,\n",
    "    class_mode='categorical')\n",
    "\n",
    "train_generator = train_datagen.flow_from_directory(train_data_dir, **flow_kwargs)\n",
    "validation_generator = test_datagen.flow_from_directory(validation_data_dir, **flow_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10\n",
      "82/82 [==============================] - 87s 1s/step - loss: 0.3458 - acc: 0.9287 - val_loss: 0.6679 - val_acc: 0.8574\n",
      "Epoch 2/10\n",
      "82/82 [==============================] - 85s 1s/step - loss: 0.0732 - acc: 0.9775 - val_loss: 0.8486 - val_acc: 0.8510\n",
      "Epoch 3/10\n",
      "82/82 [==============================] - 88s 1s/step - loss: 0.0347 - acc: 0.9886 - val_loss: 0.9541 - val_acc: 0.8510\n",
      "Epoch 4/10\n",
      "82/82 [==============================] - 88s 1s/step - loss: 0.0198 - acc: 0.9928 - val_loss: 0.8518 - val_acc: 0.8638\n",
      "Epoch 5/10\n",
      "82/82 [==============================] - 86s 1s/step - loss: 0.0202 - acc: 0.9931 - val_loss: 0.8435 - val_acc: 0.8638\n",
      "Epoch 6/10\n",
      "82/82 [==============================] - 89s 1s/step - loss: 0.0148 - acc: 0.9939 - val_loss: 1.0645 - val_acc: 0.8413\n",
      "Epoch 7/10\n",
      "82/82 [==============================] - 88s 1s/step - loss: 0.0102 - acc: 0.9960 - val_loss: 0.8923 - val_acc: 0.8526\n",
      "Epoch 8/10\n",
      "82/82 [==============================] - 86s 1s/step - loss: 0.0115 - acc: 0.9956 - val_loss: 0.6337 - val_acc: 0.8750\n",
      "Epoch 9/10\n",
      "82/82 [==============================] - 84s 1s/step - loss: 0.0076 - acc: 0.9977 - val_loss: 0.7872 - val_acc: 0.8702\n",
      "Epoch 10/10\n",
      "82/82 [==============================] - 85s 1s/step - loss: 0.0074 - acc: 0.9975 - val_loss: 0.7217 - val_acc: 0.8766\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x7fc3c4464d90>"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Train for `epochs` passes over the training generator, with the validation\n",
    "# generator supplied as validation data. The returned History is the cell output.\n",
    "model.fit_generator(generator=train_generator,\n",
    "                    validation_data=validation_generator,\n",
    "                    epochs=epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
